{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building a GPT language model... from scratch!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the whole corpus into memory as one string.\n",
    "with open(\"gpt/harry-potter.txt\", \"r\", encoding=\"utf-8\") as f:\n",
    "    text = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6250286"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the corpus, in characters.\n",
    "len(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CHAPTER ONE\n",
      "THE BOY WHO LIVED\n",
      "Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say\n",
      "that they were perfectly normal, thank you very much. They were the last\n",
      "people you'd expect to be involved in anything strange or mysterious,\n",
      "because they just didn't hold with such nonsense.\n",
      "Mr. Dursley was the director of a firm called Grunnings, which made\n",
      "drills. He was a big, beefy man with hardly any neck, although he did\n",
      "have a very large mustache. Mrs. Dursley was thin and blonde and had\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Peek at the first 500 characters of the corpus.\n",
    "print(text[:500])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\f !\"&'()*,-./0123456789:;<=>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{}~¦«»éü–—‘’“”•…\n",
      "104\n"
     ]
    }
   ],
   "source": [
    "# Character-level vocabulary: the distinct characters of the corpus, in sorted order.\n",
    "chars = sorted(set(text))\n",
    "vocab_size = len(chars)\n",
    "\n",
    "print(\"\".join(chars))\n",
    "print(vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id\n",
    "itos = dict(enumerate(chars))  # integer id -> character\n",
    "\n",
    "\n",
    "def encode(s):\n",
    "    \"\"\"Map a string to the list of integer ids of its characters.\n",
    "\n",
    "    Raises KeyError if `s` contains a character that is not in the vocabulary.\n",
    "    \"\"\"\n",
    "    return [stoi[c] for c in s]\n",
    "\n",
    "\n",
    "def decode(l):\n",
    "    \"\"\"Map an iterable of integer ids back to the corresponding string.\"\"\"\n",
    "    return \"\".join(itos[i] for i in l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[36, 76, 76, 65, 2, 74, 76, 79, 75, 70, 75, 68, 3]\n",
      "Good morning!\n"
     ]
    }
   ],
   "source": [
    "# Round trip: decoding the encoded ids must give back the original string.\n",
    "sample = \"Good morning!\"\n",
    "sample_ids = encode(sample)\n",
    "print(sample_ids)\n",
    "print(decode(sample_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/glouppe/anaconda3/envs/torch-gpu/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([6250286]), torch.int64)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Encode the full corpus into a 1-D tensor of integer token ids.\n",
    "data = torch.tensor(encode(text))\n",
    "data.shape, data.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([32, 37, 30, 45, 49, 34, 47,  2, 44, 43, 34,  0, 49, 37, 34,  2, 31, 44,\n",
       "        54,  2, 52, 37, 44,  2, 41, 38, 51, 34, 33,  0, 42, 79, 12,  2, 62, 75,\n",
       "        65,  2, 42, 79, 80, 12,  2, 33, 82, 79, 80, 73, 66, 86, 10,  2, 76, 67,\n",
       "         2, 75, 82, 74, 63, 66, 79,  2, 67, 76, 82, 79, 10,  2, 45, 79, 70, 83,\n",
       "        66, 81,  2, 33, 79, 70, 83, 66, 10,  2, 84, 66, 79, 66,  2, 77, 79, 76,\n",
       "        82, 65,  2, 81, 76,  2, 80, 62, 86,  0, 81, 69, 62, 81,  2, 81, 69, 66,\n",
       "        86,  2, 84, 66, 79, 66,  2, 77, 66, 79, 67, 66, 64, 81, 73, 86,  2, 75,\n",
       "        76, 79, 74, 62, 73, 10,  2, 81, 69, 62, 75, 72,  2, 86, 76, 82,  2, 83,\n",
       "        66, 79, 86,  2, 74, 82, 64, 69, 12,  2, 49, 69, 66, 86,  2, 84, 66, 79,\n",
       "        66,  2, 81, 69, 66,  2, 73, 62, 80, 81,  0, 77, 66, 76, 77, 73, 66,  2,\n",
       "        86, 76, 82,  6, 65,  2, 66, 85, 77, 66, 64, 81,  2, 81, 76,  2, 63, 66,\n",
       "         2, 70, 75, 83, 76, 73, 83, 66, 65,  2, 70, 75,  2, 62, 75, 86, 81, 69,\n",
       "        70, 75, 68,  2, 80, 81, 79, 62, 75, 68, 66,  2, 76, 79,  2, 74, 86, 80,\n",
       "        81, 66, 79, 70, 76, 82, 80, 10,  0, 63, 66, 64, 62, 82, 80, 66,  2, 81,\n",
       "        69, 66, 86,  2, 71, 82, 80, 81,  2, 65, 70, 65, 75,  6, 81,  2, 69, 76,\n",
       "        73, 65,  2, 84, 70, 81, 69,  2, 80, 82, 64, 69,  2, 75, 76, 75, 80, 66,\n",
       "        75, 80, 66, 12,  0, 42, 79, 12,  2, 33, 82, 79, 80, 73, 66, 86,  2, 84,\n",
       "        62, 80,  2, 81, 69, 66,  2, 65, 70, 79, 66, 64, 81, 76, 79,  2, 76, 67,\n",
       "         2, 62,  2, 67, 70, 79, 74,  2, 64, 62, 73, 73, 66, 65,  2, 36, 79, 82,\n",
       "        75, 75, 70, 75, 68, 80, 10,  2, 84, 69, 70, 64, 69,  2, 74, 62, 65, 66,\n",
       "         0, 65, 79, 70, 73, 73, 80, 12,  2, 37, 66,  2, 84, 62, 80,  2, 62,  2,\n",
       "        63, 70, 68, 10,  2, 63, 66, 66, 67, 86,  2, 74, 62, 75,  2, 84, 70, 81,\n",
       "        69,  2, 69, 62, 79, 65, 73, 86,  2, 62, 75, 86,  2, 75, 66, 64, 72, 10,\n",
       "         2, 62, 73, 81, 69, 76, 82, 68, 69,  2, 69, 66,  2, 65, 70, 65,  0, 69,\n",
       "        62, 83, 66,  2, 62,  2, 83, 66, 79, 86,  2, 73, 62, 79, 68, 66,  2, 74,\n",
       "        82, 80, 81, 62, 64, 69, 66, 12,  2, 42, 79, 80, 12,  2, 33, 82, 79, 80,\n",
       "        73, 66, 86,  2, 84, 62, 80,  2, 81, 69, 70, 75,  2, 62, 75, 65,  2, 63,\n",
       "        73, 76, 75, 65, 66,  2, 62, 75, 65,  2, 69, 62, 65,  0])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First 500 token ids of the encoded corpus.\n",
    "data[:500]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the first 90% of the tokens for training and the rest for validation.\n",
    "n = int(len(data) * 0.9)\n",
    "train_data, validation_data = data[:n], data[n:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([32, 37, 30, 45, 49, 34, 47,  2, 44])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Length of the token windows used as model inputs.\n",
    "block_size = 8\n",
    "# One extra token is needed so that the last input position also has a target.\n",
    "train_data[:block_size+1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prompt = tensor([32]), target = 37\n",
      "prompt = tensor([32, 37]), target = 30\n",
      "prompt = tensor([32, 37, 30]), target = 45\n",
      "prompt = tensor([32, 37, 30, 45]), target = 49\n",
      "prompt = tensor([32, 37, 30, 45, 49]), target = 34\n",
      "prompt = tensor([32, 37, 30, 45, 49, 34]), target = 47\n",
      "prompt = tensor([32, 37, 30, 45, 49, 34, 47]), target = 2\n",
      "prompt = tensor([32, 37, 30, 45, 49, 34, 47,  2]), target = 44\n"
     ]
    }
   ],
   "source": [
    "# Inputs and targets are the same window, shifted by one token.\n",
    "x = train_data[:block_size]\n",
    "y = train_data[1:block_size + 1]\n",
    "\n",
    "# A single window packs block_size training examples of growing context length.\n",
    "for t, target in enumerate(y):\n",
    "    prompt = x[:t + 1]\n",
    "    print(f\"prompt = {prompt}, target = {target}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[77, 79, 76, 65, 82, 64, 66, 65],\n",
      "        [ 2, 73, 76, 76, 72, 70, 75,  6],\n",
      "        [ 2, 33, 82, 74, 63, 73, 66, 11],\n",
      "        [76, 81, 81, 80,  3,  4,  0,  4]])\n",
      "tensor([[79, 76, 65, 82, 64, 66, 65,  2],\n",
      "        [73, 76, 76, 72, 70, 75,  6,  2],\n",
      "        [33, 82, 74, 63, 73, 66, 11, 65],\n",
      "        [81, 81, 80,  3,  4,  0,  4, 30]])\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(42)\n",
    "batch_size = 4  # number of sequences sampled per batch\n",
    "block_size = 8  # number of tokens per sequence\n",
    "\n",
    "\n",
    "def get_batch(data):\n",
    "    \"\"\"Sample a random batch of input/target windows from `data`.\n",
    "\n",
    "    Draws batch_size random start offsets and returns a pair (x, y) where\n",
    "    x[b] is the block_size-long window starting at offset b and y[b] is the\n",
    "    same window shifted by one, so y[b, t] is the token following x[b, t].\n",
    "    \"\"\"\n",
    "    # The upper bound leaves room for the shifted target window.\n",
    "    starts = torch.randint(len(data) - block_size, (batch_size,))\n",
    "    window = starts.unsqueeze(1) + torch.arange(block_size)  # (batch_size, block_size)\n",
    "    return data[window], data[window + 1]\n",
    "\n",
    "\n",
    "xb, yb = get_batch(train_data)\n",
    "print(xb)\n",
    "print(yb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prompt = tensor([77]), target = 79\n",
      "prompt = tensor([77, 79]), target = 76\n",
      "prompt = tensor([77, 79, 76]), target = 65\n",
      "prompt = tensor([77, 79, 76, 65]), target = 82\n",
      "prompt = tensor([77, 79, 76, 65, 82]), target = 64\n",
      "prompt = tensor([77, 79, 76, 65, 82, 64]), target = 66\n",
      "prompt = tensor([77, 79, 76, 65, 82, 64, 66]), target = 65\n",
      "prompt = tensor([77, 79, 76, 65, 82, 64, 66, 65]), target = 2\n",
      "prompt = tensor([2]), target = 73\n",
      "prompt = tensor([ 2, 73]), target = 76\n",
      "prompt = tensor([ 2, 73, 76]), target = 76\n",
      "prompt = tensor([ 2, 73, 76, 76]), target = 72\n",
      "prompt = tensor([ 2, 73, 76, 76, 72]), target = 70\n",
      "prompt = tensor([ 2, 73, 76, 76, 72, 70]), target = 75\n",
      "prompt = tensor([ 2, 73, 76, 76, 72, 70, 75]), target = 6\n",
      "prompt = tensor([ 2, 73, 76, 76, 72, 70, 75,  6]), target = 2\n",
      "prompt = tensor([2]), target = 33\n",
      "prompt = tensor([ 2, 33]), target = 82\n",
      "prompt = tensor([ 2, 33, 82]), target = 74\n",
      "prompt = tensor([ 2, 33, 82, 74]), target = 63\n",
      "prompt = tensor([ 2, 33, 82, 74, 63]), target = 73\n",
      "prompt = tensor([ 2, 33, 82, 74, 63, 73]), target = 66\n",
      "prompt = tensor([ 2, 33, 82, 74, 63, 73, 66]), target = 11\n",
      "prompt = tensor([ 2, 33, 82, 74, 63, 73, 66, 11]), target = 65\n",
      "prompt = tensor([76]), target = 81\n",
      "prompt = tensor([76, 81]), target = 81\n",
      "prompt = tensor([76, 81, 81]), target = 80\n",
      "prompt = tensor([76, 81, 81, 80]), target = 3\n",
      "prompt = tensor([76, 81, 81, 80,  3]), target = 4\n",
      "prompt = tensor([76, 81, 81, 80,  3,  4]), target = 0\n",
      "prompt = tensor([76, 81, 81, 80,  3,  4,  0]), target = 4\n",
      "prompt = tensor([76, 81, 81, 80,  3,  4,  0,  4]), target = 30\n"
     ]
    }
   ],
   "source": [
    "# Each row of the batch is an independent sequence contributing block_size examples.\n",
    "for b in range(batch_size):\n",
    "    row_x, row_y = xb[b], yb[b]\n",
    "    for t in range(block_size):\n",
    "        prompt, target = row_x[:t + 1], row_y[t]\n",
    "        print(f\"prompt = {prompt}, target = {target}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bigram language model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn \n",
    "from torch.nn import functional as F \n",
    "torch.manual_seed(42)\n",
    "\n",
    "class BigramLanguageModel(nn.Module):\n",
    "    \"\"\"Bigram model: the logits for the next token depend only on the current token.\n",
    "\n",
    "    The embedding table has shape (vocab_size, vocab_size); row i holds the\n",
    "    unnormalized scores of every token that may follow token i.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size):\n",
    "        super().__init__()\n",
    "        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return next-token logits of shape (B, T, V) for token ids x of shape (B, T).\"\"\"\n",
    "        return self.token_embedding_table(x)\n",
    "\n",
    "    @torch.no_grad()  # sampling never needs gradients, so skip autograd bookkeeping\n",
    "    def generate(self, x, max_new_tokens):\n",
    "        \"\"\"Autoregressively sample max_new_tokens tokens after the prompt x.\n",
    "\n",
    "        Args:\n",
    "            x: (B, T) integer tensor of token ids used as the prompt.\n",
    "            max_new_tokens: number of tokens appended to each sequence.\n",
    "\n",
    "        Returns:\n",
    "            A (B, T + max_new_tokens) tensor: the prompt followed by the samples.\n",
    "        \"\"\"\n",
    "        for _ in range(max_new_tokens):\n",
    "            logits = self(x)[:, -1, :]  # (B, V): only the last position predicts the next token\n",
    "            probs = F.softmax(logits, dim=-1)\n",
    "            x_next = torch.multinomial(probs, num_samples=1)  # (B, 1)\n",
    "            x = torch.cat([x, x_next], dim=1)\n",
    "        return x\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4, 8]), torch.Size([4, 8, 104]))"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lm = BigramLanguageModel(vocab_size)\n",
    "# Forward pass on the sampled batch: one vector of vocab_size logits per position.\n",
    "logits = lm(xb)\n",
    "xb.shape, logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.8146, -1.0212, -0.4949,  ...,  1.2461, -2.3065, -1.2869],\n",
       "         [-1.3412,  0.3424,  0.1963,  ...,  1.0125, -0.7147,  0.3446],\n",
       "         [ 0.5997, -0.3390,  0.1549,  ..., -0.2335, -0.7175, -2.2448],\n",
       "         ...,\n",
       "         [ 0.6960,  0.3888, -2.5335,  ..., -0.5347, -0.1319, -1.1636],\n",
       "         [-0.8146, -1.0212, -0.4949,  ...,  1.2461, -2.3065, -1.2869],\n",
       "         [ 0.1773,  0.9313, -1.1519,  ...,  0.6461,  1.0271,  0.9107]],\n",
       "\n",
       "        [[ 0.5997, -0.3390,  0.1549,  ..., -0.2335, -0.7175, -2.2448],\n",
       "         [ 1.6601, -0.5517, -0.3104,  ..., -0.7300, -1.4113,  0.3488],\n",
       "         [-0.8146, -1.0212, -0.4949,  ...,  1.2461, -2.3065, -1.2869],\n",
       "         ...,\n",
       "         [-1.3938,  0.8466, -1.7191,  ..., -0.5860,  2.0284, -0.1151],\n",
       "         [-0.8146, -1.0212, -0.4949,  ...,  1.2461, -2.3065, -1.2869],\n",
       "         [ 0.5530,  1.2586,  0.2317,  ...,  0.6008,  0.7986, -1.3825]],\n",
       "\n",
       "        [[ 1.6601, -0.5517, -0.3104,  ..., -0.7300, -1.4113,  0.3488],\n",
       "         [-0.6870,  0.3154, -1.2174,  ...,  0.7114, -0.2421,  0.9910],\n",
       "         [ 1.6601, -0.5517, -0.3104,  ..., -0.7300, -1.4113,  0.3488],\n",
       "         ...,\n",
       "         [ 0.5525,  0.8857, -1.3390,  ..., -0.4260, -0.9261,  0.3349],\n",
       "         [ 0.1773,  0.9313, -1.1519,  ...,  0.6461,  1.0271,  0.9107],\n",
       "         [-0.8146, -1.0212, -0.4949,  ...,  1.2461, -2.3065, -1.2869]],\n",
       "\n",
       "        [[ 0.5530,  1.2586,  0.2317,  ...,  0.6008,  0.7986, -1.3825],\n",
       "         [-1.3412,  0.3424,  0.1963,  ...,  1.0125, -0.7147,  0.3446],\n",
       "         [-0.8146, -1.0212, -0.4949,  ...,  1.2461, -2.3065, -1.2869],\n",
       "         ...,\n",
       "         [ 1.0420,  1.1547,  0.3103,  ...,  0.2814, -0.5159,  1.5220],\n",
       "         [ 1.0249, -0.6914, -0.5325,  ...,  1.4238,  0.8160,  0.2655],\n",
       "         [ 1.6601, -0.5517, -0.3104,  ..., -0.7300, -1.4113,  0.3488]]],\n",
       "       grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the logits computed by the forward pass.\n",
    "logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "eN\"_h‘xr*Hw2z^aD_h‘xr*Hw2z^aD_h‘xr*Hw2z^aD_h‘xr*Hw2z^aD_h‘xr*Hw2z^aD_h‘xr*Hw2z^aD_h‘xr*Hw2z^aD_h‘xr*\n"
     ]
    }
   ],
   "source": [
    "# Sample 100 characters from the model, starting from a single token of id 0.\n",
    "prompt_ids = torch.zeros((1, 1), dtype=torch.int64)\n",
    "sampled_ids = lm.generate(prompt_ids, max_new_tokens=100)\n",
    "print(decode(sampled_ids.squeeze().tolist()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The optimizer keeps references to the parameters of this particular `lm` instance:\n",
    "# re-run this cell whenever `lm` is re-created, or training will update the old model.\n",
    "optimizer = torch.optim.Adam(lm.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3217525482177734\n"
     ]
    }
   ],
   "source": [
    "# Mini-batch training: minimize the cross-entropy between the predicted\n",
    "# next-token logits and the tokens that actually follow.\n",
    "for step in range(1000):\n",
    "    optimizer.zero_grad()\n",
    "    xb, yb = get_batch(train_data)\n",
    "\n",
    "    logits = lm(xb)  # (B, T, V)\n",
    "    B, T, V = logits.shape\n",
    "    # F.cross_entropy expects (N, V) logits and (N,) targets: merge batch and time.\n",
    "    logits = logits.view(B * T, V)\n",
    "    targets = yb.view(B * T)\n",
    "    loss = F.cross_entropy(logits, targets)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "# Loss of the last mini-batch only, hence a noisy estimate.\n",
    "print(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n'"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sample again from the same one-token prompt after the training loop has run.\n",
    "prompt_ids = torch.zeros((1, 1), dtype=torch.int64)\n",
    "sampled_ids = lm.generate(prompt_ids, max_new_tokens=100)\n",
    "decode(sampled_ids.squeeze().tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Self-attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 8, 2])"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Toy input: B sequences of T positions, each position a vector of C features.\n",
    "B, T, C = 4, 8, 2\n",
    "x = torch.randn(B, T, C)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "# xbow[b, t] = mean_{i <= t} x[b, i]\n",
    "\n",
    "# v1: for loop\n",
    "xbow = torch.zeros((B, T, C))\n",
    "for b, seq in enumerate(x):  # seq is (T, C)\n",
    "    for t in range(T):\n",
    "        # mean over positions 0..t inclusive, i.e. over t + 1 rows\n",
    "        xbow[b, t] = seq[:t + 1].mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5783,  1.4785],\n",
       "        [-1.3545,  0.0231],\n",
       "        [ 0.6528, -0.1086],\n",
       "        [ 0.9647,  1.5781],\n",
       "        [ 1.2073,  0.6106],\n",
       "        [-0.0396, -0.6555],\n",
       "        [ 1.0122, -0.4351],\n",
       "        [-1.3594,  0.9402]])"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First sequence of the toy batch.\n",
    "x[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5783,  1.4785],\n",
       "        [-0.9664,  0.7508],\n",
       "        [-0.4267,  0.4644],\n",
       "        [-0.0788,  0.7428],\n",
       "        [ 0.1784,  0.7163],\n",
       "        [ 0.1421,  0.4877],\n",
       "        [ 0.2664,  0.3559],\n",
       "        [ 0.0632,  0.4289]])"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Running averages of the first sequence: row t is the mean of x[0, :t+1].\n",
    "xbow[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# v2: matmul\n",
    "# Row t of the lower-triangular matrix holds t + 1 ones; dividing by the row sum\n",
    "# turns it into uniform averaging weights over positions 0..t.\n",
    "lower = torch.tril(torch.ones(T, T))\n",
    "wei = lower / lower.sum(dim=1, keepdim=True)\n",
    "xbow2 = wei @ x  # (T, T) @ (B, T, C) -> (B, T, C), broadcast over the batch\n",
    "torch.allclose(xbow, xbow2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a = \n",
      "tensor([[1.0000, 0.0000, 0.0000],\n",
      "        [0.5000, 0.5000, 0.0000],\n",
      "        [0.3333, 0.3333, 0.3333]])\n",
      "b = \n",
      "tensor([[2., 7.],\n",
      "        [6., 4.],\n",
      "        [6., 5.]])\n",
      "c = a@b = \n",
      "tensor([[2.0000, 7.0000],\n",
      "        [4.0000, 5.5000],\n",
      "        [4.6667, 5.3333]])\n"
     ]
    }
   ],
   "source": [
    "# Worked 3x3 example: a row-normalized lower-triangular matrix averages the rows of b.\n",
    "torch.manual_seed(42)\n",
    "a = torch.tril(torch.ones(3, 3))\n",
    "a = a / a.sum(dim=1, keepdim=True)\n",
    "b = torch.randint(0, 10, (3, 2)).float()\n",
    "c = a @ b  # row i of c is the mean of rows 0..i of b\n",
    "print(f\"a = \\n{a}\")\n",
    "print(f\"b = \\n{b}\")\n",
    "print(f\"c = a@b = \\n{c}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# v3: softmax\n",
    "# Start from equal affinities (zeros), forbid looking ahead by setting future\n",
    "# positions to -inf, then let softmax turn each row into averaging weights.\n",
    "tril = torch.tril(torch.ones(T, T))\n",
    "wei = torch.zeros((T, T)).masked_fill(tril == 0, float(\"-inf\"))\n",
    "wei = wei.softmax(dim=-1)\n",
    "xbow3 = wei @ x\n",
    "torch.allclose(xbow, xbow3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 8, 32])"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# v4: data-dependent weighted sum\n",
    "torch.manual_seed(42)\n",
    "B, T, C = 4, 8, 32\n",
    "x = torch.randn(B, T, C)\n",
    "\n",
    "# single head: every position emits a query and a key vector\n",
    "head_size = 16\n",
    "key = nn.Linear(C, head_size, bias=False)\n",
    "query = nn.Linear(C, head_size, bias=False)\n",
    "k, q = key(x), query(x)  # both (B, T, H)\n",
    "\n",
    "# affinities: (B, T, H) @ (B, H, T) -> (B, T, T)\n",
    "wei = q @ k.transpose(-2, -1)\n",
    "\n",
    "# causal mask: position t may only attend to positions <= t\n",
    "tril = torch.tril(torch.ones(T, T))\n",
    "wei = wei.masked_fill(tril == 0, float(\"-inf\")).softmax(dim=-1)\n",
    "\n",
    "# the weights now depend on the data, but still aggregate the raw inputs x\n",
    "out = wei @ x\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.1905, 0.8095, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.3742, 0.0568, 0.5690, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.1288, 0.3380, 0.1376, 0.3956, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.4311, 0.0841, 0.0582, 0.3049, 0.1217, 0.0000, 0.0000, 0.0000],\n",
       "        [0.0537, 0.3205, 0.0694, 0.2404, 0.2568, 0.0592, 0.0000, 0.0000],\n",
       "        [0.3396, 0.0149, 0.5165, 0.0180, 0.0658, 0.0080, 0.0373, 0.0000],\n",
       "        [0.0165, 0.0375, 0.0144, 0.1120, 0.0332, 0.4069, 0.3136, 0.0660]],\n",
       "       grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Attention weights of the first sequence: lower triangular, every row sums to 1.\n",
    "wei[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 8, 16])"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# v5: single self-attention head (v4 plus a value projection)\n",
    "torch.manual_seed(42)\n",
    "B, T, C = 4, 8, 32\n",
    "x = torch.randn(B, T, C)\n",
    "\n",
    "# single head\n",
    "head_size = 16\n",
    "key = nn.Linear(C, head_size, bias=False)\n",
    "query = nn.Linear(C, head_size, bias=False)\n",
    "value = nn.Linear(C, head_size, bias=False)\n",
    "k, q, v = key(x), query(x), value(x)  # each (B, T, H)\n",
    "\n",
    "# causally masked attention weights, (B, T, T)\n",
    "tril = torch.tril(torch.ones(T, T))\n",
    "wei = q @ k.transpose(-2, -1)\n",
    "wei = wei.masked_fill(tril == 0, float(\"-inf\")).softmax(dim=-1)\n",
    "\n",
    "# aggregate the projected values instead of the raw inputs:\n",
    "# (B, T, T) @ (B, T, H) -> (B, T, H)\n",
    "out = wei @ v\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(1.0346), tensor(0.9618), tensor(15.1129))"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# variance\n",
    "k = torch.randn(B, T, head_size)\n",
    "q = torch.randn(B, T, head_size)\n",
    "wei = q @ k.transpose(-2, -1)\n",
    "k.var(), q.var(), wei.var()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(1.0346), tensor(0.9618), tensor(0.9446))"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# scaling the scores by 1/sqrt(head_size) brings their variance back to about 1\n",
    "wei = q @ k.transpose(-2, -1) * head_size ** -0.5\n",
    "k.var(), q.var(), wei.var()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# softmax of small-magnitude scores stays diffuse\n",
    "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the same scores multiplied by 8: softmax concentrates on the largest entry,\n",
    "# which is why attention scores are scaled down before the softmax\n",
    "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]) * 8, dim=-1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch-gpu]",
   "language": "python",
   "name": "conda-env-torch-gpu-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
